{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-tuning, Inference, and Evaluation with NVIDIA NeMo Microservices and NIM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook covers the following workflows:\n",
    "- Creating a dataset and uploading files for customizing and evaluating models\n",
    "- Running inference on base and customized models\n",
    "- Customizing and evaluating models, comparing metrics between base models and fine-tuned models\n",
    "- Running a safety check and evaluating a model using Guardrails\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deploy NeMo Microservices\n",
    "Ensure the NeMo Microservices platform is up and running, including the model downloading step for `meta/llama-3.1-8b-instruct`. Please refer to the [installation guide](https://aire.gitlab-master-pages.nvidia.com/microservices/documentation/latest/nemo-microservices/latest-internal/set-up/deploy-as-platform/index.html) for instructions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can verify that `meta/llama-3.1-8b-instruct` is deployed by querying the NIM endpoint. The response should include a model with an `id` of `meta/llama-3.1-8b-instruct`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "# URL to NeMo deployment management service\n",
    "export NEMO_URL=\"http://nemo.test\"\n",
    "\n",
    "curl -X GET \"$NEMO_URL/v1/models\" \\\n",
    "  -H \"Accept: application/json\"\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up Developer Environment\n",
    "Set up your development environment on your machine. The project uses `uv` to manage Python dependencies. From the root of the project, install dependencies and create your virtual environment:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "uv sync --extra dev\n",
    "uv pip install -U llama-stack-client\n",
    "uv pip install -e .\n",
    "source .venv/bin/activate\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build Llama Stack Image\n",
    "Build the Llama Stack image using the virtual environment you just created. For local development, set `LLAMA_STACK_DIR` to ensure your local code is used in the image. To use the production version of `llama-stack`, omit `LLAMA_STACK_DIR`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bash\n",
    "uv run --with llama-stack llama stack list-deps nvidia | xargs -L1 uv pip install\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Update the following variables in [config.py](./config.py) with your deployment URLs and API keys. The other variables are optional. You can update these to organize the resources created by this notebook.\n",
    "```python\n",
    "# (Required) NeMo Microservices URLs\n",
    "NDS_URL = \"\" # NeMo Data Store\n",
    "NEMO_URL = \"\" # Other NeMo Microservices (Customizer, Evaluator, Guardrails)\n",
    "NIM_URL = \"\" # NIM\n",
    "\n",
    "# (Required) Hugging Face Token\n",
    "HF_TOKEN = \"\"\n",
    "```"
   ]
  },
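  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before moving on, it can help to confirm that the required values are filled in. The sketch below is one way to do that. `missing_settings` is a hypothetical helper (it is not part of `config.py`), and the sample values are placeholders.\n",
    "\n",
    "```python\n",
    "def missing_settings(settings: dict) -> list:\n",
    "    '''Return the names of settings that are empty or whitespace-only.'''\n",
    "    return [name for name, value in settings.items() if not str(value or '').strip()]\n",
    "\n",
    "\n",
    "# Placeholder values for illustration; in the notebook these come from config.py\n",
    "example_settings = {\n",
    "    'NDS_URL': 'http://data-store.test',\n",
    "    'NEMO_URL': 'http://nemo.test',\n",
    "    'NIM_URL': '',\n",
    "    'HF_TOKEN': 'token',\n",
    "}\n",
    "print(missing_settings(example_settings))\n",
    "```\n",
    "\n",
    "With the placeholder values above, only `NIM_URL` is reported as missing."
   ]
  },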
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Set environment variables used by each service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from docs.notebooks.nvidia.beginner_e2e.config import *\n",
    "\n",
    "# Metadata associated with Datasets and Customization Jobs\n",
    "os.environ[\"NVIDIA_DATASET_NAMESPACE\"] = NAMESPACE\n",
    "os.environ[\"NVIDIA_PROJECT_ID\"] = PROJECT_ID\n",
    "\n",
    "# Inference env vars\n",
    "os.environ[\"NVIDIA_BASE_URL\"] = NIM_URL\n",
    "\n",
    "# Data Store env vars\n",
    "os.environ[\"NVIDIA_DATASETS_URL\"] = NEMO_URL\n",
    "\n",
    "# Customizer env vars\n",
    "os.environ[\"NVIDIA_CUSTOMIZER_URL\"] = NEMO_URL\n",
    "os.environ[\"NVIDIA_OUTPUT_MODEL_DIR\"] = CUSTOMIZED_MODEL_DIR\n",
    "\n",
    "# Evaluator env vars\n",
    "os.environ[\"NVIDIA_EVALUATOR_URL\"] = NEMO_URL\n",
    "\n",
    "# Guardrails env vars\n",
    "os.environ[\"GUARDRAILS_SERVICE_URL\"] = NEMO_URL\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Initialize the HuggingFace API client. Here, we use NeMo Data Store as the endpoint the client will invoke."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import HfApi\n",
    "import json\n",
    "import pprint\n",
    "import requests\n",
    "from time import sleep, time\n",
    "\n",
    "os.environ[\"HF_ENDPOINT\"] = f\"{NDS_URL}/v1/hf\"\n",
    "os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
    "\n",
    "hf_api = HfApi(endpoint=os.environ.get(\"HF_ENDPOINT\"), token=os.environ.get(\"HF_TOKEN\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. Initialize the Llama Stack client using the NVIDIA provider."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n",
    "\n",
    "client = LlamaStackAsLibraryClient(\"nvidia\")\n",
    "client.initialize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "5. Define a few helper functions we'll use later that wait for async jobs to complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_stack.apis.common.job_types import JobStatus\n",
    "\n",
    "def wait_customization_job(job_id: str, polling_interval: int = 30, timeout: int = 3600):\n",
    "    start_time = time()\n",
    "\n",
    "    response = client.post_training.job.status(job_uuid=job_id)\n",
    "    job_status = response.status\n",
    "\n",
    "    print(f\"Waiting for Customization job {job_id} to finish.\")\n",
    "    print(f\"Job status: {job_status} after {time() - start_time} seconds.\")\n",
    "\n",
    "    while job_status in [JobStatus.scheduled.value, JobStatus.in_progress.value]:\n",
    "        sleep(polling_interval)\n",
    "        response = client.post_training.job.status(job_uuid=job_id)\n",
    "        job_status = response.status\n",
    "\n",
    "        print(f\"Job status: {job_status} after {time() - start_time} seconds.\")\n",
    "\n",
    "        if time() - start_time > timeout:\n",
    "            raise RuntimeError(f\"Customization Job {job_id} took more than {timeout} seconds.\")\n",
    "        \n",
    "    return job_status\n",
    "\n",
    "def wait_eval_job(benchmark_id: str, job_id: str, polling_interval: int = 10, timeout: int = 6000):\n",
    "    start_time = time()\n",
    "    job_status = client.eval.jobs.status(benchmark_id=benchmark_id, job_id=job_id)\n",
    "\n",
    "    print(f\"Waiting for Evaluation job {job_id} to finish.\")\n",
    "    print(f\"Job status: {job_status} after {time() - start_time} seconds.\")\n",
    "\n",
    "    while job_status.status in [JobStatus.scheduled.value, JobStatus.in_progress.value]:\n",
    "        sleep(polling_interval)\n",
    "        job_status = client.eval.jobs.status(benchmark_id=benchmark_id, job_id=job_id)\n",
    "\n",
    "        print(f\"Job status: {job_status} after {time() - start_time} seconds.\")\n",
    "\n",
    "        if time() - start_time > timeout:\n",
    "            raise RuntimeError(f\"Evaluation Job {job_id} took more than {timeout} seconds.\")\n",
    "\n",
    "    return job_status\n",
    "\n",
    "# When creating a customized model, NIM asynchronously loads the model in its model registry.\n",
    "# After this, we can run inference on the new model. This helper function waits for NIM to pick up the new model.\n",
    "def wait_nim_loads_customized_model(model_id: str, polling_interval: int = 10, timeout: int = 300):\n",
    "    start_time = time()\n",
    "\n",
    "    print(f\"Checking if NIM has loaded customized model {model_id}.\")\n",
    "\n",
    "    # Poll the NIM model list until the model appears or the timeout elapses\n",
    "    while time() - start_time <= timeout:\n",
    "        sleep(polling_interval)\n",
    "\n",
    "        response = requests.get(f\"{NIM_URL}/v1/models\")\n",
    "        if model_id in [model[\"id\"] for model in response.json()[\"data\"]]:\n",
    "            print(f\"Model {model_id} available after {time() - start_time} seconds.\")\n",
    "            return\n",
    "\n",
    "        print(f\"Model {model_id} not available after {time() - start_time} seconds.\")\n",
    "\n",
    "    raise RuntimeError(f\"Model {model_id} not available after {timeout} seconds.\")"
   ]
  },
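  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three helpers above share one pattern: check a condition, sleep, and give up after a timeout. As a minimal sketch of that pattern in isolation, the example below uses a toy condition instead of a real service. `poll_until` is a hypothetical helper, not part of Llama Stack or NeMo Microservices.\n",
    "\n",
    "```python\n",
    "from time import monotonic, sleep\n",
    "\n",
    "\n",
    "def poll_until(check, polling_interval: float = 1.0, timeout: float = 10.0):\n",
    "    '''Call check() until it returns a truthy value or the timeout elapses.'''\n",
    "    deadline = monotonic() + timeout\n",
    "    while True:\n",
    "        result = check()\n",
    "        if result:\n",
    "            return result\n",
    "        if monotonic() >= deadline:\n",
    "            raise TimeoutError(f'Condition not met within {timeout} seconds.')\n",
    "        sleep(polling_interval)\n",
    "\n",
    "\n",
    "# Toy condition that succeeds on the third call\n",
    "calls = {'count': 0}\n",
    "\n",
    "\n",
    "def ready():\n",
    "    calls['count'] += 1\n",
    "    return calls['count'] >= 3\n",
    "\n",
    "\n",
    "poll_until(ready, polling_interval=0.01, timeout=1.0)\n",
    "print(calls['count'])\n",
    "```\n",
    "\n",
    "Checking the deadline inside the loop is what guarantees the wait ends even if the condition never becomes true."
   ]
  },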
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload Dataset Using the HuggingFace Client"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start by creating a dataset with the `sample_squad_data` files. This data is drawn from the Stanford Question Answering Dataset (SQuAD), a reading comprehension dataset consisting of questions posed on a set of Wikipedia articles. The answer to each question is a segment of text from the corresponding passage, or the question is unanswerable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_squad_dataset_name = \"sample-squad-test\"\n",
    "repo_id = f\"{NAMESPACE}/{sample_squad_dataset_name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the repo\n",
    "response = hf_api.create_repo(repo_id, repo_type=\"dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the files from the local folder\n",
    "hf_api.upload_folder(\n",
    "    folder_path=\"./sample_data/sample_squad_data/training\",\n",
    "    path_in_repo=\"training\",\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"dataset\",\n",
    ")\n",
    "hf_api.upload_folder(\n",
    "    folder_path=\"./sample_data/sample_squad_data/validation\",\n",
    "    path_in_repo=\"validation\",\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"dataset\",\n",
    ")\n",
    "hf_api.upload_folder(\n",
    "    folder_path=\"./sample_data/sample_squad_data/testing\",\n",
    "    path_in_repo=\"testing\",\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"dataset\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the dataset\n",
    "response = client.datasets.register(\n",
    "    purpose=\"post-training/messages\",\n",
    "    dataset_id=sample_squad_dataset_name,\n",
    "    source={\n",
    "        \"type\": \"uri\",\n",
    "        \"uri\": f\"hf://datasets/{repo_id}\"\n",
    "    },\n",
    "    metadata={\n",
    "        \"format\": \"json\",\n",
    "        \"description\": \"Test sample_squad_data dataset for NVIDIA E2E notebook\",\n",
    "        \"provider_id\": \"nvidia\",\n",
    "    }\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the files URL\n",
    "response = requests.get(\n",
    "    url=f\"{NEMO_URL}/v1/datasets/{NAMESPACE}/{sample_squad_dataset_name}\",\n",
    ")\n",
    "assert response.status_code in (200, 201), f\"Status Code {response.status_code} Failed to fetch dataset {response.text}\"\n",
    "\n",
    "dataset_obj = response.json()\n",
    "print(\"Files URL:\", dataset_obj[\"files_url\"])\n",
    "assert dataset_obj[\"files_url\"] == f\"hf://datasets/{repo_id}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll use an entry from the `sample_squad_data` test data to verify we can run inference using NVIDIA NIM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pprint\n",
    "\n",
    "with open(\"./sample_data/sample_squad_data/testing/testing.jsonl\", \"r\") as f:\n",
    "    examples = [json.loads(line) for line in f]\n",
    "\n",
    "# Get the user prompt from the last example\n",
    "sample_prompt = examples[-1][\"prompt\"]\n",
    "pprint.pprint(sample_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test inference\n",
    "response = client.chat.completions.create(\n",
    "    messages=[\n",
    "        {\"role\": \"user\", \"content\": sample_prompt}\n",
    "    ],\n",
    "    model=BASE_MODEL,\n",
    "    max_tokens=20,\n",
    "    temperature=0.7,\n",
    ")\n",
    "print(f\"Inference response: {response.choices[0].message.content}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run an Evaluation, we'll first register a benchmark. A benchmark corresponds to an Evaluation Config in NeMo Evaluator, which contains the metadata to use when launching an Evaluation Job. Here, we'll create a benchmark that uses the testing file uploaded in the previous step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark_id = \"test-eval-config\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "simple_eval_config = {\n",
    "    \"benchmark_id\": benchmark_id,\n",
    "    \"dataset_id\": \"\",\n",
    "    \"scoring_functions\": [],\n",
    "    \"metadata\": {\n",
    "        \"type\": \"custom\",\n",
    "        \"params\": {\"parallelism\": 8},\n",
    "        \"tasks\": {\n",
    "            \"qa\": {\n",
    "                \"type\": \"completion\",\n",
    "                \"params\": {\n",
    "                    \"template\": {\n",
    "                        \"prompt\": \"{{prompt}}\",\n",
    "                        \"max_tokens\": 20,\n",
    "                        \"temperature\": 0.7,\n",
    "                        \"top_p\": 0.9,\n",
    "                    },\n",
    "                },\n",
    "                \"dataset\": {\"files_url\": f\"hf://datasets/{repo_id}/testing/testing.jsonl\"},\n",
    "                \"metrics\": {\n",
    "                    \"bleu\": {\n",
    "                        \"type\": \"bleu\",\n",
    "                        \"params\": {\"references\": [\"{{ideal_response}}\"]},\n",
    "                    },\n",
    "                    \"string-check\": {\n",
    "                        \"type\": \"string-check\",\n",
    "                        \"params\": {\"check\": [\"{{ideal_response | trim}}\", \"equals\", \"{{output_text | trim}}\"]},\n",
    "                    },\n",
    "                },\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
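  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about the `string-check` metric configured above: with the `equals` operator and the `trim` filter, a sample counts as correct only when the model output matches the reference exactly once surrounding whitespace is removed. The sketch below illustrates that idea in plain Python. It is an illustration under that assumption, not NeMo Evaluator's implementation, and `string_check_equals` and `accuracy` are hypothetical names.\n",
    "\n",
    "```python\n",
    "def string_check_equals(reference: str, output: str) -> bool:\n",
    "    '''Exact match after stripping leading and trailing whitespace.'''\n",
    "    return reference.strip() == output.strip()\n",
    "\n",
    "\n",
    "def accuracy(pairs) -> float:\n",
    "    '''Fraction of (reference, output) pairs that match exactly.'''\n",
    "    pairs = list(pairs)\n",
    "    if not pairs:\n",
    "        return 0.0\n",
    "    return sum(string_check_equals(ref, out) for ref, out in pairs) / len(pairs)\n",
    "\n",
    "\n",
    "# Made-up (reference, output) pairs\n",
    "samples = [\n",
    "    ('Paris', ' Paris '),   # matches once whitespace is trimmed\n",
    "    ('1066', 'In 1066'),    # extra words, so no exact match\n",
    "]\n",
    "print(accuracy(samples))\n",
    "```\n",
    "\n",
    "Under this reading, a model that answers in full sentences can score low on `string-check` even when its answers contain the right span."
   ]
  },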
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register a benchmark, which creates an Evaluation Config\n",
    "response = client.benchmarks.register(\n",
    "    benchmark_id=benchmark_id,\n",
    "    dataset_id=repo_id,\n",
    "    scoring_functions=simple_eval_config[\"scoring_functions\"],\n",
    "    metadata=simple_eval_config[\"metadata\"]\n",
    ")\n",
    "print(f\"Created benchmark {benchmark_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch a simple evaluation with the benchmark\n",
    "response = client.eval.run_eval(\n",
    "    benchmark_id=benchmark_id,\n",
    "    benchmark_config={\n",
    "        \"eval_candidate\": {\n",
    "            \"type\": \"model\",\n",
    "            \"model\": BASE_MODEL,\n",
    "            \"sampling_params\": {}\n",
    "        }\n",
    "    }\n",
    ")\n",
    "job_id = response.model_dump()[\"job_id\"]\n",
    "print(f\"Created evaluation job {job_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the job to complete\n",
    "job = wait_eval_job(benchmark_id=benchmark_id, job_id=job_id, polling_interval=5, timeout=600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Job {job_id} status: {job.status}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_results = client.eval.jobs.retrieve(benchmark_id=benchmark_id, job_id=job_id)\n",
    "print(f\"Job results: {json.dumps(job_results.model_dump(), indent=2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract bleu score and assert it's within range\n",
    "initial_bleu_score = job_results.scores[benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"bleu\"][\"scores\"][\"corpus\"][\"value\"]\n",
    "print(f\"Initial bleu score: {initial_bleu_score}\")\n",
    "\n",
    "assert initial_bleu_score >= 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract accuracy and assert it's within range\n",
    "initial_accuracy_score = job_results.scores[benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"string-check\"][\"scores\"][\"string-check\"][\"value\"]\n",
    "print(f\"Initial accuracy: {initial_accuracy_score}\")\n",
    "\n",
    "assert initial_accuracy_score >= 0"
   ]
  },
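  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The nested lookup used in the two cells above is repeated for each metric, so a small helper can keep it readable. This is a sketch: `get_metric_value` is a hypothetical helper, and the sample dictionary only mirrors the key path used in this notebook with made-up numbers.\n",
    "\n",
    "```python\n",
    "def get_metric_value(aggregated_results: dict, task: str, metric: str, score: str) -> float:\n",
    "    '''Walk tasks -> task -> metrics -> metric -> scores -> score -> value.'''\n",
    "    return aggregated_results['tasks'][task]['metrics'][metric]['scores'][score]['value']\n",
    "\n",
    "\n",
    "# Made-up numbers, shaped like the key path used above\n",
    "sample_results = {\n",
    "    'tasks': {\n",
    "        'qa': {\n",
    "            'metrics': {\n",
    "                'bleu': {'scores': {'corpus': {'value': 12.5}}},\n",
    "                'string-check': {'scores': {'string-check': {'value': 0.25}}},\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "print(get_metric_value(sample_results, 'qa', 'bleu', 'corpus'))\n",
    "print(get_metric_value(sample_results, 'qa', 'string-check', 'string-check'))\n",
    "```\n",
    "\n",
    "In this notebook, the first argument would be `job_results.scores[benchmark_id].aggregated_results`."
   ]
  },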
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Customization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've established our baseline Evaluation metrics, we'll customize a model using our training data uploaded previously."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the customization job\n",
    "response = client.post_training.supervised_fine_tune(\n",
    "    job_uuid=\"\",\n",
    "    model=BASE_MODEL,\n",
    "    training_config={\n",
    "        \"n_epochs\": 2,\n",
    "        \"data_config\": {\n",
    "            \"batch_size\": 16,\n",
    "            \"dataset_id\": sample_squad_dataset_name,\n",
    "        },\n",
    "        \"optimizer_config\": {\n",
    "            \"lr\": 0.0001,\n",
    "        }\n",
    "    },\n",
    "    algorithm_config={\n",
    "        \"type\": \"LoRA\",\n",
    "        \"adapter_dim\": 16,\n",
    "        \"adapter_dropout\": 0.1,\n",
    "        \"alpha\": 16,\n",
    "        # NOTE: These fields are required, but not directly used by NVIDIA\n",
    "        \"rank\": 8,\n",
    "        \"lora_attn_modules\": [],\n",
    "        \"apply_lora_to_mlp\": True,\n",
    "        \"apply_lora_to_output\": False\n",
    "    },\n",
    "    hyperparam_search_config={},\n",
    "    logger_config={},\n",
    "    checkpoint_dir=\"\",\n",
    ")\n",
    "\n",
    "job_id = response.job_uuid\n",
    "print(f\"Created job with ID: {job_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the job to complete\n",
    "job_status = wait_customization_job(job_id=job_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Job {job_id} status: {job_status}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the fine-tuning job succeeds, we can't immediately run inference on the customized model. In the background, NIM loads newly created models and makes them available for inference. This process typically takes less than 5 minutes, so we wait for our customized model to be picked up before attempting to run inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the customized model has been picked up by NIM;\n",
    "# We allow up to 5 minutes for the LoRA adapter to be loaded\n",
    "wait_nim_loads_customized_model(model_id=CUSTOMIZED_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, NIM can run inference on the customized model. However, to use the Llama Stack client to run inference, we need to explicitly register the model first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that inference with the new model works\n",
    "from llama_stack.apis.models.models import ModelType\n",
    "\n",
    "# First, register the customized model\n",
    "client.models.register(\n",
    "    model_id=CUSTOMIZED_MODEL_DIR,\n",
    "    model_type=ModelType.llm,\n",
    "    provider_id=\"nvidia\",\n",
    ")\n",
    "\n",
    "response = client.completions.create(\n",
    "    prompt=\"Complete the sentence using one word: Roses are red, violets are \",\n",
    "    stream=False,\n",
    "    model=CUSTOMIZED_MODEL_DIR,\n",
    "    temperature=0.7,\n",
    "    top_p=0.9,\n",
    "    max_tokens=20,\n",
    ")\n",
    "print(f\"Inference response: {response.choices[0].text}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Customized Model\n",
    "Now that we've customized the model, let's run another Evaluation to compare its performance with the base model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch a simple evaluation with the same benchmark with the customized model\n",
    "response = client.eval.run_eval(\n",
    "    benchmark_id=benchmark_id,\n",
    "    benchmark_config={\n",
    "        \"eval_candidate\": {\n",
    "            \"type\": \"model\",\n",
    "            \"model\": CUSTOMIZED_MODEL_DIR,\n",
    "            \"sampling_params\": {}\n",
    "        }\n",
    "    }\n",
    ")\n",
    "job_id = response.model_dump()[\"job_id\"]\n",
    "print(f\"Created evaluation job {job_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the job to complete\n",
    "customized_model_job = wait_eval_job(benchmark_id=benchmark_id, job_id=job_id, polling_interval=5, timeout=600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Job {job_id} status: {customized_model_job.status}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "customized_model_job_results = client.eval.jobs.retrieve(benchmark_id=benchmark_id, job_id=job_id)\n",
    "print(f\"Job results: {json.dumps(customized_model_job_results.model_dump(), indent=2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract bleu score and assert it's within range\n",
    "customized_bleu_score = customized_model_job_results.scores[benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"bleu\"][\"scores\"][\"corpus\"][\"value\"]\n",
    "print(f\"Customized bleu score: {customized_bleu_score}\")\n",
    "\n",
    "assert customized_bleu_score >= 35"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract accuracy and assert it's within range\n",
    "customized_accuracy_score = customized_model_job_results.scores[benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"string-check\"][\"scores\"][\"string-check\"][\"value\"]\n",
    "print(f\"Customized accuracy: {customized_accuracy_score}\")\n",
    "\n",
    "assert customized_accuracy_score >= 0.45"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We expect to see an improvement in the BLEU score and accuracy in the customized model's evaluation results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure the customized model evaluation is better than the original model evaluation\n",
    "print(f\"customized_bleu_score - initial_bleu_score: {customized_bleu_score - initial_bleu_score}\")\n",
    "assert (customized_bleu_score - initial_bleu_score) >= 27\n",
    "\n",
    "print(f\"customized_accuracy_score - initial_accuracy_score: {customized_accuracy_score - initial_accuracy_score}\")\n",
    "assert (customized_accuracy_score - initial_accuracy_score) >= 0.4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload Chat Dataset Using the HuggingFace Client\n",
    "Repeat the fine-tuning and evaluation workflow with a chat-style dataset, which has a list of `messages` instead of a `prompt` and `completion`."
   ]
  },
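  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the difference between the two formats, the sketch below converts a prompt/completion record into a chat-style record. The field names `prompt` and `completion` follow the description above. The sample text and the `to_messages` helper are hypothetical, and the actual sample files may include additional roles (for example, a system message).\n",
    "\n",
    "```python\n",
    "def to_messages(record: dict) -> dict:\n",
    "    '''Wrap a prompt/completion pair as a user turn and an assistant turn.'''\n",
    "    return {\n",
    "        'messages': [\n",
    "            {'role': 'user', 'content': record['prompt']},\n",
    "            {'role': 'assistant', 'content': record['completion']},\n",
    "        ]\n",
    "    }\n",
    "\n",
    "\n",
    "# Made-up record for illustration\n",
    "example_record = {'prompt': 'What is the capital of France?', 'completion': 'Paris'}\n",
    "chat_record = to_messages(example_record)\n",
    "print(chat_record['messages'][-1]['content'])\n",
    "```\n",
    "\n",
    "In the chat format, the final message plays the role that `completion` plays in the prompt/completion format."
   ]
  },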
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_squad_messages_dataset_name = \"test-squad-messages-dataset\"\n",
    "repo_id = f\"{NAMESPACE}/{sample_squad_messages_dataset_name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the repo\n",
    "res = hf_api.create_repo(repo_id, repo_type=\"dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upload the files from the local folder\n",
    "hf_api.upload_folder(\n",
    "    folder_path=\"./sample_data/sample_squad_messages/training\",\n",
    "    path_in_repo=\"training\",\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"dataset\",\n",
    ")\n",
    "hf_api.upload_folder(\n",
    "    folder_path=\"./sample_data/sample_squad_messages/validation\",\n",
    "    path_in_repo=\"validation\",\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"dataset\",\n",
    ")\n",
    "hf_api.upload_folder(\n",
    "    folder_path=\"./sample_data/sample_squad_messages/testing\",\n",
    "    path_in_repo=\"testing\",\n",
    "    repo_id=repo_id,\n",
    "    repo_type=\"dataset\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the dataset\n",
    "response = client.datasets.register(\n",
    "    purpose=\"post-training/messages\",\n",
    "    dataset_id=sample_squad_messages_dataset_name,\n",
    "    source={\n",
    "        \"type\": \"uri\",\n",
    "        \"uri\": f\"hf://datasets/{repo_id}\"\n",
    "    },\n",
    "    metadata={\n",
    "        \"format\": \"json\",\n",
    "        \"description\": \"Test sample_squad_messages dataset for NVIDIA E2E notebook\",\n",
    "        \"provider_id\": \"nvidia\",\n",
    "    }\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the files URL\n",
    "response = requests.get(\n",
    "    url=f\"{NEMO_URL}/v1/datasets/{NAMESPACE}/{sample_squad_messages_dataset_name}\",\n",
    ")\n",
    "assert response.status_code in (200, 201), f\"Status Code {response.status_code} Failed to fetch dataset {response.text}\"\n",
    "dataset_obj = response.json()\n",
    "print(\"Files URL:\", dataset_obj[\"files_url\"])\n",
    "assert dataset_obj[\"files_url\"] == f\"hf://datasets/{repo_id}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference with chat/completions\n",
    "We'll use an entry from the `sample_squad_messages` test data to verify we can run inference using NVIDIA NIM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./sample_data/sample_squad_messages/testing/testing.jsonl\", \"r\") as f:\n",
    "    examples = [json.loads(line) for line in f]\n",
    "\n",
    "# Get every message except the final one (the reference answer) from the last example\n",
    "sample_messages = examples[-1][\"messages\"][:-1]\n",
    "pprint.pprint(sample_messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test inference\n",
    "response = client.chat.completions.create(\n",
    "    messages=sample_messages,\n",
    "    model=BASE_MODEL,\n",
    "    max_tokens=20,\n",
    "    temperature=0.7,\n",
    ")\n",
    "assert response.choices[0].message.content is not None\n",
    "print(f\"Inference response: {response.choices[0].message.content}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate with chat dataset\n",
    "We'll register a new benchmark that uses the chat-style testing file uploaded previously."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "benchmark_id = \"test-eval-config-chat\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Eval Config for the chat-style benchmark; it is registered in the next cell\n",
    "simple_eval_config = {\n",
    "    \"benchmark_id\": benchmark_id,\n",
    "    \"dataset_id\": \"\",\n",
    "    \"scoring_functions\": [],\n",
    "    \"metadata\": {\n",
    "        \"type\": \"custom\",\n",
    "        \"params\": {\"parallelism\": 8},\n",
    "        \"tasks\": {\n",
    "            \"qa\": {\n",
    "                \"type\": \"completion\",\n",
    "                \"params\": {\n",
    "                    \"template\": {\n",
    "                        \"messages\": [\n",
    "                            {\"role\": \"{{item.messages[0].role}}\", \"content\": \"{{item.messages[0].content}}\"},\n",
    "                            {\"role\": \"{{item.messages[1].role}}\", \"content\": \"{{item.messages[1].content}}\"},\n",
    "                        ],\n",
    "                        \"max_tokens\": 20,\n",
    "                        \"temperature\": 0.7,\n",
    "                        \"top_p\": 0.9,\n",
    "                    },\n",
    "                },\n",
    "                \"dataset\": {\"files_url\": f\"hf://datasets/{repo_id}/testing/testing.jsonl\"},\n",
    "                \"metrics\": {\n",
    "                    \"bleu\": {\n",
    "                        \"type\": \"bleu\",\n",
    "                        \"params\": {\"references\": [\"{{item.messages[2].content | trim}}\"]},\n",
    "                    },\n",
    "                    \"string-check\": {\n",
    "                        \"type\": \"string-check\",\n",
    "                        \"params\": {\"check\": [\"{{item.messages[2].content}}\", \"equals\", \"{{output_text | trim}}\"]},\n",
    "                    },\n",
    "                },\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
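  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `{{item...}}` placeholders in the config above are filled in once per dataset record. The sketch below illustrates that substitution with the `jinja2` package, assuming the evaluator's templates are Jinja-compatible and that each record holds three messages: the first two form the prompt and the third is the reference answer. The record content and the roles shown are invented for illustration and are not taken from the real dataset.\n",
    "\n",
    "```python\n",
    "from jinja2 import Template\n",
    "\n",
    "# Hypothetical record in the chat-style layout (assumed roles and content).\n",
    "item = {\n",
    "    \"messages\": [\n",
    "        {\"role\": \"system\", \"content\": \"Answer using the context.\"},\n",
    "        {\"role\": \"user\", \"content\": \"Context: Paris is in France. Question: Where is Paris?\"},\n",
    "        {\"role\": \"assistant\", \"content\": \" France \"},\n",
    "    ]\n",
    "}\n",
    "\n",
    "# Render one prompt field and the reference answer the same way the config refers to them.\n",
    "prompt = Template(\"{{item.messages[1].content}}\").render(item=item)\n",
    "reference = Template(\"{{item.messages[2].content | trim}}\").render(item=item)\n",
    "\n",
    "print(prompt)\n",
    "print(repr(reference))  # the trim filter strips surrounding whitespace\n",
    "```\n",
    "\n",
    "Rendering a record by hand like this can help confirm that the indices in the template line up with the dataset layout before launching an evaluation job."
   ]
  },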
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = client.benchmarks.register(\n",
    "    benchmark_id=benchmark_id,\n",
    "    dataset_id=repo_id,\n",
    "    scoring_functions=simple_eval_config[\"scoring_functions\"],\n",
    "    metadata=simple_eval_config[\"metadata\"]\n",
    ")\n",
    "print(f\"Created benchmark {benchmark_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch a simple evaluation with the benchmark\n",
    "response = client.eval.run_eval(\n",
    "    benchmark_id=benchmark_id,\n",
    "    benchmark_config={\n",
    "        \"eval_candidate\": {\n",
    "            \"type\": \"model\",\n",
    "            \"model\": BASE_MODEL,\n",
    "            \"sampling_params\": {}\n",
    "        }\n",
    "    }\n",
    ")\n",
    "job_id = response.model_dump()[\"job_id\"]\n",
    "print(f\"Created evaluation job {job_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the job to complete\n",
    "job = wait_eval_job(benchmark_id=benchmark_id, job_id=job_id, polling_interval=5, timeout=600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Job {job_id} status: {job.status}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_results = client.eval.jobs.retrieve(benchmark_id=benchmark_id, job_id=job_id)\n",
    "print(f\"Job results: {json.dumps(job_results.model_dump(), indent=2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract bleu score and assert it's within range\n",
    "initial_bleu_score = job_results.scores[benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"bleu\"][\"scores\"][\"corpus\"][\"value\"]\n",
    "print(f\"Initial bleu score: {initial_bleu_score}\")\n",
    "\n",
    "assert initial_bleu_score >= 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract accuracy and assert it's within range\n",
    "initial_accuracy_score = job_results.scores[benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"string-check\"][\"scores\"][\"string-check\"][\"value\"]\n",
    "print(f\"Initial accuracy: {initial_accuracy_score}\")\n",
    "\n",
    "assert initial_accuracy_score >= 0.2"
   ]
  },
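  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both score lookups above follow the same nested path through `aggregated_results`. As a minimal sketch, that path could be wrapped in one small helper. `get_metric_value` is a hypothetical name, not part of the client library, and the sample dictionary only mirrors the key layout used in this notebook; its numbers are placeholders, not real results.\n",
    "\n",
    "```python\n",
    "def get_metric_value(aggregated_results, task, metric, score):\n",
    "    \"\"\"Return tasks -> task -> metrics -> metric -> scores -> score -> value.\"\"\"\n",
    "    return aggregated_results[\"tasks\"][task][\"metrics\"][metric][\"scores\"][score][\"value\"]\n",
    "\n",
    "\n",
    "# Illustrative structure only; the values are placeholders.\n",
    "sample_results = {\n",
    "    \"tasks\": {\n",
    "        \"qa\": {\n",
    "            \"metrics\": {\n",
    "                \"bleu\": {\"scores\": {\"corpus\": {\"value\": 12.5}}},\n",
    "                \"string-check\": {\"scores\": {\"string-check\": {\"value\": 0.25}}},\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "print(get_metric_value(sample_results, \"qa\", \"bleu\", \"corpus\"))\n",
    "print(get_metric_value(sample_results, \"qa\", \"string-check\", \"string-check\"))\n",
    "```\n",
    "\n",
    "With a real job, the first argument would be `job_results.scores[benchmark_id].aggregated_results`."
   ]
  },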
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Customization with chat dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've established baseline evaluation metrics for the chat-style dataset, we'll customize the base model using the training data uploaded previously."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "customized_chat_model_name = \"test-messages-model\"\n",
    "customized_chat_model_version = \"v1\"\n",
    "customized_chat_model_dir = f\"{NAMESPACE}/{customized_chat_model_name}@{customized_chat_model_version}\"\n",
    "\n",
    "# NOTE: The output model name is derived from the environment variable. We need to re-initialize the client\n",
    "# here so the Post Training API picks up the updated value.\n",
    "os.environ[\"NVIDIA_OUTPUT_MODEL_DIR\"] = customized_chat_model_dir\n",
    "client.initialize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = client.post_training.supervised_fine_tune(\n",
    "    job_uuid=\"\",\n",
    "    model=BASE_MODEL,\n",
    "    training_config={\n",
    "        \"n_epochs\": 2,\n",
    "        \"data_config\": {\n",
    "            \"batch_size\": 16,\n",
    "            \"dataset_id\": sample_squad_messages_dataset_name,\n",
    "        },\n",
    "        \"optimizer_config\": {\n",
    "            \"lr\": 0.0001,\n",
    "        }\n",
    "    },\n",
    "    algorithm_config={\n",
    "        \"type\": \"LoRA\",\n",
    "        \"adapter_dim\": 16,\n",
    "        \"adapter_dropout\": 0.1,\n",
    "        \"alpha\": 16,\n",
    "        # NOTE: These fields are required, but not directly used by NVIDIA\n",
    "        \"rank\": 8,\n",
    "        \"lora_attn_modules\": [],\n",
    "        \"apply_lora_to_mlp\": True,\n",
    "        \"apply_lora_to_output\": False\n",
    "    },\n",
    "    hyperparam_search_config={},\n",
    "    logger_config={},\n",
    "    checkpoint_dir=\"\",\n",
    ")\n",
    "\n",
    "job_id = response.job_uuid\n",
    "print(f\"Created job with ID: {job_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the customization job to complete\n",
    "job_status = wait_customization_job(job_id=job_id, polling_interval=30, timeout=3600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Job {job_id} status: {job_status}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the customized model has been picked up by NIM.\n",
    "# We allow up to 5 minutes for the LoRA adapter to be loaded.\n",
    "wait_nim_loads_customized_model(model_id=customized_chat_model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that inference with the new customized model works\n",
    "from llama_stack.apis.models.models import ModelType\n",
    "\n",
    "# First, register the customized model\n",
    "client.models.register(\n",
    "    model_id=customized_chat_model_dir,\n",
    "    model_type=ModelType.llm,\n",
    "    provider_id=\"nvidia\",\n",
    ")\n",
    "\n",
    "response = client.completions.create(\n",
    "    prompt=\"Complete the sentence using one word: Roses are red, violets are \",\n",
    "    stream=False,\n",
    "    model=customized_chat_model_dir,\n",
    "    temperature=0.7,\n",
    "    top_p=0.9,\n",
    "    max_tokens=20,\n",
    ")\n",
    "print(f\"Inference response: {response.choices[0].text}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The completion text lives under choices[0].text, as printed in the previous cell\n",
    "assert len(response.choices[0].text) > 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Customized Model with chat dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch evaluation for customized model\n",
    "response = client.eval.run_eval(\n",
    "    benchmark_id=benchmark_id,\n",
    "    benchmark_config={\n",
    "        \"eval_candidate\": {\n",
    "            \"type\": \"model\",\n",
    "            \"model\": customized_chat_model_dir,\n",
    "            \"sampling_params\": {}\n",
    "        }\n",
    "    }\n",
    ")\n",
    "job_id = response.model_dump()[\"job_id\"]\n",
    "print(f\"Created evaluation job {job_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job = wait_eval_job(benchmark_id=benchmark_id, job_id=job_id, polling_interval=5, timeout=600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_results = client.eval.jobs.retrieve(benchmark_id=benchmark_id, job_id=job_id)\n",
    "print(f\"Job results: {json.dumps(job_results.model_dump(), indent=2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract bleu score and assert it's within range\n",
    "customized_bleu_score = job_results.scores[benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"bleu\"][\"scores\"][\"corpus\"][\"value\"]\n",
    "print(f\"Customized bleu score: {customized_bleu_score}\")\n",
    "\n",
    "assert customized_bleu_score >= 40"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract accuracy and assert it's within range\n",
    "customized_accuracy_score = job_results.scores[benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"string-check\"][\"scores\"][\"string-check\"][\"value\"]\n",
    "print(f\"Customized accuracy: {customized_accuracy_score}\")\n",
    "\n",
    "assert customized_accuracy_score >= 0.47"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure the customized model evaluation is better than the original model evaluation\n",
    "print(f\"customized_bleu_score - initial_bleu_score: {customized_bleu_score - initial_bleu_score}\")\n",
    "assert (customized_bleu_score - initial_bleu_score) >= 20\n",
    "\n",
    "print(f\"customized_accuracy_score - initial_accuracy_score: {customized_accuracy_score - initial_accuracy_score}\")\n",
    "assert (customized_accuracy_score - initial_accuracy_score) >= 0.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Guardrails"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can check messages for safety violations using Guardrails. We'll start by registering a shield for the `llama-3.1-nemoguard-8b-content-safety` model. Ensure the `shield_id` matches the ID of the model we'll use for the safety check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "shield_id = \"nvidia/llama-3.1-nemoguard-8b-content-safety\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client.shields.register(shield_id=shield_id, provider_id=\"nvidia\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "message = {\"role\": \"user\", \"content\": \"You are stupid.\"}\n",
    "response = client.safety.run_shield(\n",
    "    messages=[message],\n",
    "    shield_id=shield_id,\n",
    "    params={}\n",
    ")\n",
    "\n",
    "print(f\"Safety response: {response}\")\n",
    "assert response.violation.user_message == \"Sorry I cannot do this.\""
   ]
  },
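  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The assertion above expects a violation. As a hedged sketch, the helper below shows one way to handle both outcomes, assuming `violation` is `None` when the shield finds nothing to flag. `describe_shield_result` is a hypothetical helper, and `SimpleNamespace` objects stand in for real responses so the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def describe_shield_result(response):\n",
    "    \"\"\"Summarize a shield response that may or may not carry a violation.\"\"\"\n",
    "    violation = getattr(response, \"violation\", None)\n",
    "    if violation is None:\n",
    "        return \"allowed\"\n",
    "    return f\"blocked: {violation.user_message}\"\n",
    "\n",
    "\n",
    "# Stand-ins for real responses (assumed shape).\n",
    "blocked = SimpleNamespace(violation=SimpleNamespace(user_message=\"Sorry I cannot do this.\"))\n",
    "allowed = SimpleNamespace(violation=None)\n",
    "\n",
    "print(describe_shield_result(blocked))\n",
    "print(describe_shield_result(allowed))\n",
    "```\n",
    "\n",
    "Checking for a missing violation first avoids an `AttributeError` when a message passes the safety check."
   ]
  },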
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Guardrails also exposes OpenAI-compatible endpoints that you can use to run inference with guardrails applied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check inference with guardrails\n",
    "message = {\"role\": \"user\", \"content\": \"You are stupid.\"}\n",
    "response = requests.post(\n",
    "    url=f\"{NEMO_URL}/v1/guardrail/chat/completions\",\n",
    "    json={\n",
    "        \"model\": shield_id,\n",
    "        \"messages\": [message],\n",
    "        \"max_tokens\": 150\n",
    "    }\n",
    ")\n",
    "\n",
    "assert response.status_code in (200, 201), f\"Status Code {response.status_code} Failed to run inference with guardrail {response.text}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the response contains the predefined message\n",
    "guardrails_content = response.json()[\"choices\"][0][\"message\"][\"content\"]\n",
    "print(f\"Guardrails response: {guardrails_content}\")\n",
    "assert guardrails_content == \"I'm sorry, I can't respond to that.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check inference without guardrails\n",
    "response = client.chat.completions.create(\n",
    "    messages=[message],\n",
    "    model=BASE_MODEL,\n",
    "    max_tokens=150,\n",
    ")\n",
    "assert response.choices[0].message.content is not None\n",
    "print(f\"Inference response: {response.choices[0].message.content}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Guardrails Evaluation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "guardrails_dataset_name = \"content-safety-test-data\"\n",
    "guardrails_repo_id = f\"{NAMESPACE}/{guardrails_dataset_name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create dataset and upload test data\n",
    "hf_api.create_repo(guardrails_repo_id, repo_type=\"dataset\")\n",
    "hf_api.upload_folder(\n",
    "    folder_path=\"./sample_data/sample_content_safety_test_data\",\n",
    "    path_in_repo=\"\",\n",
    "    repo_id=guardrails_repo_id,\n",
    "    repo_type=\"dataset\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "guardrails_benchmark_id = \"test-guardrails-eval-config\"\n",
    "guardrails_eval_config = {\n",
    "    \"benchmark_id\": guardrails_benchmark_id,\n",
    "    \"dataset_id\": \"\",\n",
    "    \"scoring_functions\": [],\n",
    "    \"metadata\": {\n",
    "        \"type\": \"custom\",\n",
    "        \"params\": {\"parallelism\": 8},\n",
    "        \"tasks\": {\n",
    "            \"qa\": {\n",
    "                \"type\": \"completion\",\n",
    "                \"params\": {\n",
    "                    \"template\": {\n",
    "                        \"messages\": [\n",
    "                            {\"role\": \"user\", \"content\": \"{{item.prompt}}\"},\n",
    "                        ],\n",
    "                        \"max_tokens\": 20,\n",
    "                        \"temperature\": 0.7,\n",
    "                        \"top_p\": 0.9,\n",
    "                    },\n",
    "                },\n",
    "                \"dataset\": {\"files_url\": f\"hf://datasets/{guardrails_repo_id}/content_safety_input.jsonl\"},\n",
    "                \"metrics\": {\n",
    "                    \"bleu\": {\n",
    "                        \"type\": \"bleu\",\n",
    "                        \"params\": {\"references\": [\"{{item.ideal_response}}\"]},\n",
    "                    },\n",
    "                },\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Evaluation for model, without guardrails. First, register the benchmark.\n",
    "response = client.benchmarks.register(\n",
    "    benchmark_id=guardrails_benchmark_id,\n",
    "    dataset_id=guardrails_repo_id,\n",
    "    scoring_functions=guardrails_eval_config[\"scoring_functions\"],\n",
    "    metadata=guardrails_eval_config[\"metadata\"]\n",
    ")\n",
    "print(f\"Created benchmark {guardrails_benchmark_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run Evaluation for model, without guardrails\n",
    "response = client.eval.run_eval(\n",
    "    benchmark_id=guardrails_benchmark_id,\n",
    "    benchmark_config={\n",
    "        \"eval_candidate\": {\n",
    "            \"type\": \"model\",\n",
    "            \"model\": BASE_MODEL,\n",
    "            \"sampling_params\": {}\n",
    "        }\n",
    "    }\n",
    ")\n",
    "job_id = response.model_dump()[\"job_id\"]\n",
    "print(f\"Created evaluation job {job_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the job to complete\n",
    "job = wait_eval_job(benchmark_id=guardrails_benchmark_id, job_id=job_id, polling_interval=5, timeout=600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Job {job_id} status: {job.status}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_results = client.eval.jobs.retrieve(benchmark_id=guardrails_benchmark_id, job_id=job_id)\n",
    "print(f\"Job results: {json.dumps(job_results.model_dump(), indent=2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start Evaluation for model, with guardrails\n",
    "response = requests.post(\n",
    "    url=f\"{NEMO_URL}/v1/evaluation/jobs\",\n",
    "    json={\n",
    "        \"config\": guardrails_eval_config[\"metadata\"],\n",
    "        \"target\": {\n",
    "            \"type\": \"model\",\n",
    "            \"model\": {\n",
    "                \"api_endpoint\": {\n",
    "                    \"url\": \"http://nemo-guardrails:7331/v1/guardrail/completions\",\n",
    "                    \"model_id\": \"meta/llama-3.1-8b-instruct\",\n",
    "                }\n",
    "            },\n",
    "        },\n",
    "    }\n",
    ")\n",
    "\n",
    "# Fail early with the server's message if the job was not created\n",
    "assert response.status_code in (200, 201), f\"Status Code {response.status_code} Failed to create evaluation job with guardrails {response.text}\"\n",
    "job_id_with_guardrails = response.json()[\"id\"]\n",
    "print(f\"Created evaluation job with guardrails {job_id_with_guardrails}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the job to complete\n",
    "job = wait_eval_job(benchmark_id=guardrails_benchmark_id, job_id=job_id_with_guardrails, polling_interval=5, timeout=600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_results_with_guardrails = client.eval.jobs.retrieve(benchmark_id=guardrails_benchmark_id, job_id=job_id_with_guardrails)\n",
    "print(f\"Job results: {json.dumps(job_results_with_guardrails.model_dump(), indent=2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bleu_score_no_guardrails = job_results.scores[guardrails_benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"bleu\"][\"scores\"][\"corpus\"][\"value\"]\n",
    "print(f\"bleu_score_no_guardrails: {bleu_score_no_guardrails}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bleu_score_with_guardrails = job_results_with_guardrails.scores[guardrails_benchmark_id].aggregated_results[\"tasks\"][\"qa\"][\"metrics\"][\"bleu\"][\"scores\"][\"corpus\"][\"value\"]\n",
    "print(f\"bleu_score_with_guardrails: {bleu_score_with_guardrails}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expect the BLEU score to go from 3 to 33 once guardrails are applied\n",
    "print(f\"bleu_score_with_guardrails - bleu_score_no_guardrails: {bleu_score_with_guardrails - bleu_score_no_guardrails}\")\n",
    "assert (bleu_score_with_guardrails - bleu_score_no_guardrails) >= 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"NVIDIA E2E Flow successful.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
